import numpy as np
import gymnasium as gym
import os
from gymnasium import spaces
import robosuite
import imageio
import copy
import cv2
from Environment.environment import Environment, EnvObject, Action, Goal, strip_instance
from Environment.Environments.AirHockey.air_hockey_specs import air_hockey_variants, define_object_dicts
from ACState.object_dict import ObjDict
from Record.file_management import numpy_factored, display_frame
from collections import deque
from airhockey import AirHockeyEnv, renderers
import yaml


class AirGoal(Goal):
    """Desired puck position for the air hockey environment.

    ``attribute`` holds the desired goal. The achieved goal is the first
    ``partial`` features of the "Target" object's state (see
    ``get_achieved_goal_state``); the goal counts as reached when the L2 norm
    of their difference is below ``goal_epsilon``.

    Required keyword arguments:
        bounds: ``(lower, upper)`` arrays bounding the goal.
        all_names: ordered object names of the environment.
        goal_epsilon: success threshold used by ``check_goal``.
        env: underlying simulator; its goal radius is updated by ``set_goal_epsilon``.
    """

    def __init__(self, **kwargs):
        self.name = "Goal"
        self.attribute = np.ones(3)  # placeholder until a goal is assigned via set_state
        self.interaction_trace = list()
        # index into the stacked object states built by get_achieved_goal; -4 selects
        # "Target" assuming all_names ends with [..., "Target", "Goal", "Done", "Reward"]
        self.target_idx = -4
        self.partial = 2 # should be the length of the indices to use as goals, SET IN SUBCLASS
        self.bounds = kwargs["bounds"]
        self.bounds_lower = self.bounds[0]
        self.range = self.bounds[1] - self.bounds[0]
        self.all_names = kwargs["all_names"]
        self.goal_epsilon = kwargs["goal_epsilon"]
        self.env = kwargs["env"]

    def generate_bounds(self):
        """Return ``(lower, upper)`` bounds restricted to the first ``partial`` goal features."""
        # slice both ends so lower and upper always have matching lengths
        return self.bounds_lower[:self.partial], self.bounds[1][:self.partial]

    def sample_goal(self):
        """Sample a goal uniformly within ``bounds``, one value per bound dimension."""
        # TODO: shouldn't be called
        # draw as many values as the bounds have dimensions (previously a hard-coded 3,
        # which fails to broadcast for goal bounds of any other length)
        goal = np.random.rand(*np.shape(self.range)) * self.range + self.bounds_lower
        return goal

    def get_achieved_goal(self, env):
        """Return the achieved goal read from ``env``'s current object states.

        Each object's state (assumed 1-D) is zero-padded to the longest state so
        that all of them can be stacked along a new leading object axis.
        """
        states = [env.object_name_dict[n].get_state() for n in self.all_names]
        longest = max(len(s) for s in states)
        stacked = np.stack([np.pad(s, (0, longest - s.shape[0])) for s in states], axis=0)
        return self.get_achieved_goal_state(stacked)

    def get_achieved_goal_state(self, object_state, fidx=None):
        """Select the Target object's goal features from stacked object states.

        The last two axes of ``object_state`` are (object, feature); any leading
        axes are kept. ``fidx`` is accepted for interface compatibility and unused.
        """
        return object_state[..., self.target_idx, :self.partial]

    def add_interaction(self, reached_goal):
        """Record an interaction with "Target" when the goal was reached."""
        if reached_goal:
            self.interaction_trace += ["Target"]

    def get_state(self):
        """Return the desired goal."""
        return self.attribute

    def set_state(self, goal=None):
        """Replace the desired goal when ``goal`` is given; return the current goal."""
        if goal is not None:
            self.attribute = goal
        return self.attribute

    def set_goal_epsilon(self, goal_epsilon):
        """Update the success threshold here and in the simulator's goal radius."""
        self.env.base_goal_radius = goal_epsilon
        self.env.goal_radius = goal_epsilon
        # check_goal reads self.goal_epsilon directly, so keep it in sync with the
        # simulator even if the base implementation stores the value elsewhere
        self.goal_epsilon = goal_epsilon
        return super().set_goal_epsilon(goal_epsilon)

    def check_goal(self, env):
        """Return True when the achieved goal is within ``goal_epsilon`` (L2 norm) of the desired goal."""
        return np.linalg.norm(self.get_achieved_goal(env) - self.attribute) < self.goal_epsilon

class AirObject(EnvObject):
    """Simulator-backed object whose state can fall back to a slice of the observation.

    ``bounds`` is a ``(start, stop)`` pair of offsets into
    ``env.env.get_obs()``. It is only consulted when the state is requested
    before any explicit state has been stored.
    """

    def __init__(self, name, env, bounds):
        super().__init__(name)
        self.env = env  # the owning environment wrapper; its ``.env`` provides get_obs()
        self.state = None  # filled by set_state(state) or lazily from the observation
        self.bounds = bounds

    def set_state(self, state=None):
        """Store ``state`` when given, then return the current state.

        Called without an argument, a previously stored state is returned
        unchanged; if nothing is stored yet, the ``bounds`` slice of the
        simulator observation is cached and returned.
        """
        if state is not None:
            self.state = state
            return self.state
        if self.state is None:
            start, stop = self.bounds[0], self.bounds[1]
            self.state = self.env.env.get_obs()[start:stop]
        return self.state


def init_from_config(config_pth, horizon=-1, seed=-1):
    """Build an ``AirHockeyEnv`` from a YAML config file.

    Args:
        config_pth: path to the YAML file. It must contain the ``air_hockey``,
            ``n_training_steps`` and ``algorithm`` keys.
        horizon: episode length written to ``max_timesteps``; non-positive
            values are replaced by a large cap (no practical limit).
        seed: simulator seed; negative values draw a random seed in [0, 10000).

    Returns:
        ``(env, params)`` where ``params`` is the seeded copy of the
        ``air_hockey`` section used to construct ``env``.
    """
    with open(config_pth, 'r') as cfg_file:
        cfg = yaml.safe_load(cfg_file)

    # params aliases cfg['air_hockey'], so the updates below land in the config too
    params = cfg['air_hockey']
    params["max_timesteps"] = horizon if horizon > 0 else 100000 # a large number for no max timesteps
    params['n_training_steps'] = cfg['n_training_steps']

    # goal observations are only returned for SAC runs on goal-based tasks
    params['return_goal_obs'] = 'sac' == cfg['algorithm'] and 'goal' in params['task']

    seeded_params = params.copy()
    if seed < 0:
        seed = np.random.randint(10000)
    seeded_params['seed'] = seed

    return AirHockeyEnv(seeded_params), seeded_params


class RobotAirHockey(Environment):
    """Factored-state wrapper around the ``AirHockeyEnv`` simulator.

    Each object (Action, Paddle, optional Obstacles, Target puck, Goal, Done,
    Reward) is an entry of ``object_name_dict`` whose state is refreshed on
    every ``step``/``reset``. Interaction traces are built from simulator
    contacts. ``step`` resets the simulator itself when an episode ends.
    """

    def __init__(self, variant="default", horizon=30, renderable=False, fixed_limits=False, flat_obs=False, append_id=False, frameskip=1, seed=-1):
        """Build the simulator from the variant's config and set up factored objects.

        Args:
            variant: key into ``air_hockey_variants`` selecting the YAML config.
            horizon: episode length passed to the simulator; non-positive means no practical limit.
            renderable: if True, a frame is captured after every step/reset.
            fixed_limits: select the fixed ranges/dynamics (currently copies of the true ones).
            flat_obs: if True, ``get_state`` returns one concatenated array.
            append_id: stored as an attribute; not read by this class.
            frameskip: stored as an attribute; ``step`` does not apply it.
            seed: simulator seed; negative draws a random one.
        """
        super().__init__()
        self.self_reset = True  # step() resets the simulator itself when done
        self.fixed_limits = fixed_limits
        self.variant = variant
        self.config_pth = air_hockey_variants[variant]
        self.backend = "robosuite" if variant.find("robosuite") != -1 else "box2d"
        self.horizon = horizon
        self.env, self.args = init_from_config(self.config_pth, horizon, seed=seed)
        self.args = ObjDict(self.args)
        self.num_obstacles = self.args.num_blocks

        # environment properties
        self.num_actions = -1  # must be defined, -1 for continuous. Only needed for primitive actions
        self.name = "AirHockey"  # required for an environment
        self.discrete_mode = False  # TODO: support for discrete actions not yet implemented
        self.discrete_actions = self.discrete_mode
        self.frameskip = frameskip
        self.timeout_penalty = -horizon

        # spaces
        ranges, dynamics, position_masks, instanced = define_object_dicts(self.env)
        self.action_shape = (2,)
        self.action_space = spaces.Box(low=-np.ones(2), high=np.ones(2))
        self.action = Action(not self.discrete_actions, self.action_shape[0])
        self.renderable = renderable
        self.renderer = renderers.AirHockeyRenderer(self.env) if renderable else None
        self.flat_obs = flat_obs
        self.append_id = append_id

        # running values
        self.itr = 0
        self.total_itr = 0
        self.total_reward = 0
        self.non_passive_trace_count = 0

        # state components
        self.extracted_state = None

        # factorized state properties
        # TODO: no distinction between blocks and obstacles at the moment
        obs_names = ["Obstacle"] if self.num_obstacles > 0 else []
        # must be initialized; this list controls the ordering of objects everywhere else
        self.object_names = ["Action", "Paddle"] + obs_names + ["Target", "Goal", "Done", "Reward"]
        # wrapper name -> simulator body name ("" when there is no simulator body)
        self.object_sim_names = {"Action": "", "Paddle": "paddle_ego", "Target": "puck_0", "Goal": "", "Done": "", "Reward": ""}
        self.object_sim_names = {**self.object_sim_names, **{"Obstacle" + str(i): "block_" + str(i) for i in range(self.num_obstacles)}}
        # simulator body name -> wrapper name, used to translate contacts in assign_traces
        self.sim_object_names = {value: key for key, value in self.object_sim_names.items() if len(value) > 0}
        # TODO: only box2d supported at the moment, though robosuite should be addable
        # must be initialized: object name -> length of its state
        self.object_sizes = {"Action": 2, "Paddle": 4, "Target": 4, "Goal": 2, "Done": 1, "Reward": 1}
        if self.num_obstacles > 0:
            self.object_sizes["Obstacle"] = 2
        self.flat_indices = []
        ranges_fixed, dynamics_fixed = copy.deepcopy(ranges), copy.deepcopy(dynamics)  # TODO: implement fixed ranges and dynamics
        self.object_range = ranges if not self.fixed_limits else ranges_fixed  # min/max values for each object's features
        self.object_dynamics = dynamics if not self.fixed_limits else dynamics_fixed
        self.object_range_true = ranges
        self.object_dynamics_true = dynamics
        self.position_masks = position_masks

        # TODO: multiobject support not currently implemented
        self.object_instanced = instanced
        # names get a numeric suffix only when an object has more than one instance
        # NOTE(review): if instanced["Obstacle"] == 1 this yields "Obstacle", while
        # obs_to_dict and set_sim_objects use "Obstacle0"; verify num_blocks == 1 works.
        self.all_names = [
            (name + str(i) if instanced[name] > 1 else name)
            for name in self.object_names
            for i in range(instanced[name])
        ]
        self.num_objects = len(self.all_names)
        self.instance_length = len(self.all_names)
        self.object_name_dict = dict()
        # simulator-backed objects get a (start, stop) slice of the flat observation,
        # used by AirObject only as a fallback state source
        # NOTE(review): these offsets follow all_names order (Paddle, Obstacles, Target),
        # whereas obs_to_dict places Target before the obstacles; confirm before relying on it.
        len_up_to = 0
        for name in self.all_names:
            if name in ["Action", "Reward", "Done", "Goal"]:
                continue
            size = self.object_sizes[strip_instance(name)]
            self.object_name_dict[name] = AirObject(name, self, (len_up_to, len_up_to + size))
            len_up_to += size
        self.goal = AirGoal(all_names=self.all_names, goal_epsilon=self.args.base_goal_radius, bounds=ranges["Goal"], env=self.env)
        self.goal_space = spaces.Box(low=ranges["Goal"][0], high=ranges["Goal"][1])
        self.object_name_dict = {**{"Action": self.action, "Reward": self.reward, "Done": self.done, "Goal": self.goal}, **self.object_name_dict}
        self.objects = [self.object_name_dict[n] for n in self.all_names]
        self.obstacles = [n for n in self.all_names if n.find("Obstacle") != -1]
        self.valid_names = self.all_names
        # passive trace: every object depends on itself, and Paddle (index 1) on Action (index 0)
        self.passive_trace = np.eye(len(self.all_names))
        self.passive_trace[1, 0] = 1
        self.trace_graph = np.zeros((self.num_objects, self.num_objects))

        # position mask
        # NOTE(review): actions and the goal are 2-dimensional here, yet pos_size is 3;
        # kept unchanged because external code may depend on it. Confirm the intended value.
        self.pos_size = 3
        self.length, self.width = self.env.length, self.env.width

        obs = self.reset()
        self.trace = self.get_full_current_trace()

        self.reward_collect = 0
        # reuse the initial observation rather than resetting the simulator a second time
        if self.flat_obs:
            self.observation_space = spaces.Box(low=-1, high=1, shape=obs.shape)
        else:
            # NOTE(review): fixed size not derived from the observation; confirm
            self.observation_space = spaces.Box(low=-1, high=1, shape=[9])

    def set_named_state(self, obs_dict, set_objects=True):
        """Push per-object states from ``obs_dict`` into the object instances.

        Simulator objects and Goal are updated through ``set_state``, and their
        key must be present. Action, Reward and Done only have their
        ``attribute`` overwritten when their key exists. Nothing happens when
        ``set_objects`` is False.
        """
        if not set_objects:
            return
        for name in self.all_names:
            if name in ["Action", "Reward", "Done"]:
                if name in obs_dict:
                    self.object_name_dict[name].attribute = obs_dict[name]
            else:
                self.object_name_dict[name].set_state(obs_dict[name])

    def step(self, action, render=False):  # render will NOT change renderable, so it will still render or not render
        """Apply ``action`` for one simulator step and return the new factored state.

        Args:
            action: continuous action forwarded to the simulator.
            render: unused; rendering is controlled solely by ``renderable``.

        Returns:
            ``(state, reward, done, info)``:
            - ``reward`` is the goal-reached indicator from ``self.goal.check_goal``,
              not the simulator's reward.
            - ``state`` is the post-step state, taken before any automatic reset.
            - ``info`` describes this same transition.
        """
        self.reset_traces()
        next_obs, reward, term, trunc, info = self.env.step(action)
        next_obs = self.obs_to_dict(next_obs)
        done = term | trunc
        # np.squeeze also tolerates a plain Python float reward
        next_obs["Action"], next_obs["Done"], next_obs["Reward"] = action, done, np.squeeze(reward)

        self.set_named_state(next_obs)  # sets the state objects
        # the goal-conditioned reward replaces the simulator reward stored above
        rew = (self.goal.check_goal(self)).squeeze()
        self.reward_collect += rew
        self.reward.attribute = rew

        # handle rendering
        self.frame = self.renderer.get_frame() if self.renderable else None

        # interaction traces for this transition
        self.assign_traces()
        self.trace = self.get_factor_graph(complete_graph=True)
        self.factor_graph = self.get_factor_graph(complete_graph=False)
        self.trace_graph += self.trace

        full_state = self.get_state()

        # step timers
        self.itr += 1
        self.total_itr += 1
        if np.linalg.norm(self.trace - self.passive_trace) != 0:
            self.non_passive_trace_count += 1

        # Build info BEFORE the automatic reset, so that trace, factor graph and goal
        # success describe this transition rather than the next episode's start.
        # The simulator's own info is dropped (TODO: too populated to be useful),
        # except its truncation flag, so true terminations are not reported as timeouts.
        info = self.get_info({"TimeLimit.truncated": trunc})
        if self.done.attribute:
            self.reset()  # also resets self.itr
        return full_state, self.reward.attribute, self.done.attribute, info

    def get_state(self, render=False):
        """Return the current state.

        The result is a flat concatenation of all object states when ``flat_obs`` is set.
        Otherwise it is ``{"raw_state": frame, "factored_state": {...}}``, where
        ``factored_state`` also carries "VALID_NAMES" and "TRACE". ``render`` is unused.
        """
        if self.flat_obs:
            return np.concatenate([obj.get_state() for obj in self.objects])
        factored_state = {obj.name: obj.get_state() for obj in self.objects}
        factored_state["VALID_NAMES"] = self.valid_binary(self.valid_names)
        factored_state["TRACE"] = self.trace
        return {"raw_state": self.frame, "factored_state": factored_state}

    def get_info(self, info=None):
        """Return the info dict for the current state; entries of ``info`` override the defaults.

        Goal entries (achieved_goal, desired_goal, success) are None when there
        is no "Goal" object.
        """
        achieved_goal = desired_goal = success = None
        if "Goal" in self.object_name_dict:
            achieved_goal = self.goal.get_achieved_goal(self)
            desired_goal = self.goal.get_state()
            success = self.goal.check_goal(self)
        new_info = {"TimeLimit.truncated": self.done.attribute, "trace": self.trace, "factor_graph": self.factor_graph, "valid": self.valid_binary(self.valid_names), "achieved_goal": achieved_goal, "desired_goal": desired_goal, "success": success}
        if info is not None:
            return {**new_info, **info}
        return new_info

    def assign_traces(self):
        """Record which objects each object interacted with during the last simulator update.

        Paddle is always marked as driven by Action. Simulator bodies (the puck
        and any obstacles) get the wrapper names of everything they touch.
        Action, Done, Reward and Goal are skipped.
        """
        _object_contacts, contact_names = self.env.simulator.get_contacts()
        for name in self.all_names:
            if name == "Paddle":
                self.object_name_dict[name].interaction_trace.append("Action")
            elif name in ["Action", "Done", "Reward", "Goal"]:
                continue
            else:  # the puck and any obstacles
                this_contacts = contact_names[self.object_sim_names[name]]
                self.object_name_dict[name].interaction_trace = [self.sim_object_names[n] for n in this_contacts]

    def reset_traces(self):
        """Clear the interaction trace of every object."""
        for name in self.all_names:
            self.object_name_dict[name].interaction_trace = list()

    def obs_to_dict(self, obs):
        """Split a flat simulator observation into per-object arrays.

        Layout: Paddle ``[0:4]`` and Target ``[4:8]``. For "many_blocks_vel",
        the obstacle blocks follow, and Goal is whatever remains at the end.
        TODO: only the "vel" and "many_blocks_vel" observation types are supported.

        Raises:
            NotImplementedError: for any other ``obs_type``.
        """
        if self.args.obs_type == "vel":
            return {"Paddle": obs[:4], "Target": obs[4:8], "Goal": obs[8:]}
        elif self.args.obs_type == "many_blocks_vel":
            odim = self.object_sizes["Obstacle"]
            return {**{"Paddle": obs[:4], "Target": obs[4:8], "Goal": obs[8 + self.num_obstacles * odim:]},
                    **{"Obstacle" + str(i): obs[8 + i * odim: 8 + odim + i * odim] for i in range(self.num_obstacles)}}
        else:
            raise NotImplementedError("Obs type " + self.args.obs_type + " not supported")

    def set_sim_objects(self):
        """Map wrapper object names to the simulator's body objects (box2d backend only).

        Raises:
            NotImplementedError: for any backend other than "box2d".
        """
        if self.backend == "box2d":
            self.sim_objects = {"Action": None, "Paddle": self.env.simulator.paddles['paddle_ego'], "Target": self.env.simulator.pucks['puck_0'], "Goal": "", "Done": "", "Reward": ""}
            self.sim_objects = {**self.sim_objects, **{"Obstacle" + str(i): self.env.simulator.blocks['block_' + str(i)] for i in range(self.num_obstacles)}}
        else:
            raise NotImplementedError("Backend " + self.backend + " not implemented")

    def reset(self, goal=None, **kwargs):
        """Reset the simulator and all wrapper state, and return the initial state.

        Args:
            goal: optional desired goal. When given, it replaces the simulator's
                ``goal_pos``; otherwise the simulator's own goal is adopted.
            **kwargs: accepted for interface compatibility and ignored.
        """
        obs, info = self.env.reset()
        self.set_sim_objects()
        obs = self.obs_to_dict(obs)
        self.valid_names = self.all_names  # TODO: implement valid names on resets
        self.itr = 0
        # clear traces from the previous episode, as step() does, so Paddle's "Action"
        # entry is not appended on top of stale entries and Goal's trace does not leak
        self.reset_traces()
        self.assign_traces()
        self.trace = self.get_factor_graph(complete_graph=True)
        self.factor_graph = self.get_factor_graph(complete_graph=False)

        self.reward_collect = 0
        self.set_named_state(obs)
        self.frame = self.renderer.get_frame() if self.renderable else None

        # reset should handle goal resampling, so just take the new environment goal
        if goal is None:
            self.goal.set_state(self.env.goal_pos)
        else:
            self.goal.set_state(goal)
            self.env.goal_pos = goal
        return self.get_state()

    def render(self):
        """Return the most recently captured frame (None unless ``renderable``)."""
        return self.frame

    def seed(self, seed):
        """Reseed the simulator by resetting it with ``seed``.

        NOTE(review): wrapper objects are not refreshed here; call ``reset()`` afterwards if needed.
        """
        self.env.reset(seed=seed)